import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from tabulate import tabulate
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter
from tqdm import tqdm
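
# Benchmarks F.scaled_dot_product_attention with a lower-right causal mask
# passed two ways: materialized up front as a dense tensor, and as the lazy
# CausalBias tensor subclass, which lets SDPA dispatch to fused causal
# kernels without ever materializing the mask.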


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
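    """Return the median runtime of func(*args, **kwargs) in microseconds,
    measured with torch.utils.benchmark after a short warmup."""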
    # Warm up so one-time setup costs (kernel compilation, allocation) are excluded
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    materialized_mask_time: float
    attn_mask_subclass_time: float

    def get_entries(self) -> List:
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return list(self.config.asdict().values()) + self.results.get_entries()


def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
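    # Query and key/value may have different sequence lengths (cross-attention shapes)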
    q_shape = (batch_size, q_sequence_length, embed_dim)
    kv_shape = (batch_size, kv_sequence_length, embed_dim)

    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    return make_q(), make_kv(), make_kv()


class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, embed_dim, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.q_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.k_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.v_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.num_heads = num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Union[torch.Tensor, CausalBias],
    ):
        query_projected = F.linear(query, self.q_proj_weight)
        key_projected = F.linear(key, self.k_proj_weight)
        value_projected = F.linear(value, self.v_proj_weight)

        # Split heads: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        query = query_projected.view(
            query_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key_projected.view(
            key_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value = value_projected.view(
            value_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=0.0,
        )

        # Merge heads back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embed_dim)
        attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
        return F.linear(attn, self.out_proj)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.k_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)
        # Non-zero init so the correctness check in run_single_experiment is not vacuous
        nn.init.xavier_uniform_(self.out_proj)


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    device = torch.device("cuda")
    composite_mha = CompositeMHA(
        config.num_heads, config.embed_dim, device, config.dtype
    )
    composite_mha.reset_parameters()
    query, key, value = generate_inputs(
        config.batch_size,
        config.q_seq_len,
        config.k_seq_len,
        config.embed_dim,
        config.dtype,
        device,
    )
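    # The same lower-right causal mask, expressed two ways: the CausalBias
    # subclass stays lazy so SDPA can pick a fused causal kernel, while
    # _materialize() eagerly builds the dense tensor the generic path consumes.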
    attn_mask = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    attn_mask_tensor = attn_mask._materialize(device)

    materialized_mask_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask_tensor
    )
    attn_mask_subclass_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask
    )
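    # Sanity check: both mask representations must produce identical outputs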
    torch.testing.assert_close(
        composite_mha(query, key, value, attn_mask_tensor),
        composite_mha(query, key, value, attn_mask),
    )

    return ExperimentResults(
        materialized_mask_time=materialized_mask_time,
        attn_mask_subclass_time=attn_mask_subclass_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8, 16, 128]
    num_heads = [16, 32]
    q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [
        torch.bfloat16,
    ]
    all_configs = []
    for bsz, heads, (q_seq_len, kv_seq_len), embed_dim, dtype in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                dtype=dtype,
            )
        )

    return all_configs


def calculate_speedup(results: ExperimentResults) -> float:
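    # A value > 1.0 means the CausalBias subclass path was faster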
    return results.materialized_mask_time / results.attn_mask_subclass_time


def print_results(results: List[Experiment]):
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Build table rows: the Average row leaves the config columns blank (via
    # dict.fromkeys), while the Max/Min rows carry their full configs
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    # Print table
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))


def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    # Time each config with a materialized causal mask vs. the CausalBias subclass
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
